對抗生成網路進階應用

林嶔 (Lin, Chin)

Lesson 16

解決對抗生成網路的訓練問題(1)

– 在當初的神經網路我們為了訓練一個很深的網路,常常需要對超參數做大量的嘗試修正,有時候還需要自編碼器的輔助,直到Residual Learning結束了這一切。

– 我們回顧一下我們的Cross-entropy損失函數以及他的導函數:

\[ \begin{align} CE(y, p) & = \frac{{1}}{n}\sum \limits_{i=1}^{n} -\left(y_{i} \cdot log(p_{i}) + (1-y_{i}) \cdot log(1-p_{i})\right) \\ \frac{\partial}{\partial p}CE(y, p) & = \frac{p-y}{p(1-p)} \end{align} \]

\[ \begin{align} S(x) & =\frac{1}{1+e^{-x}} \\ \frac{\partial}{\partial x}S(x) & = S(x)(1-S(x)) \end{align} \]

解決對抗生成網路的訓練問題(2)

– 但對於Generator而言呢,他這時候需要大量的更新試圖重新騙過Discriminator,但這時候他的梯度將是…

\[ \begin{align} \lim_{p \rightarrow 1} CE(0, p) & = - log(1-p) \\ \frac{\partial}{\partial p} \lim_{p \rightarrow 1} CE(0, p) & = \frac{p}{p(1-p)} \end{align} \]

\[ \begin{align} \frac{\partial}{\partial x} \lim_{S(x) \rightarrow 1} CE(0, p) & = \frac{S(x)^2(1-S(x))}{S(x)(1-S(x))} \end{align} \]

– 除此之外,對於Discriminator以及一般網路隨著\(p \rightarrow 1\)的過程中,他的梯度會慢慢變小,這有點學習率遞減的概念,但此時對於Generator而言卻是學習率遞增的,而學習率不見得比較大就收斂比較快!

解決對抗生成網路的訓練問題(3)

  1. 去除Sigmoid函數,因為它會在極端狀況下導致近似值計算失準。

  2. 無論Discriminator跟Generator誰佔優勢,選擇一個平滑的損失函數來描述目前的競賽狀況。

F16_2

註:數值越小代表越好!

解決對抗生成網路的訓練問題(4)

F16_1

\[ \begin{align} loss(y, x) & = (1-y)x - yx \end{align} \]

實作WGAN(1)

library(imager)
library(magrittr)
library(mxnet)

# Custom iterator for conditional WGAN training. Each CSV row holds one image:
# column 1 is the digit label, the remaining 784 columns are 28x28 pixels.
# Yields, per batch: noise, real images, real/fake digit conditions, and the
# critic target labels (real = 0, fake = 1, generator target = 0; in this
# file's WGAN loss a smaller value is better).
my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    # Delegate file reading / batching to the built-in CSV iterator.
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    # Row 1 is the label; the rest are pixel values.
                                    val.x <- val[-1,]
                                    batch_size <- ncol(val.x)
                                    val.x <- val.x / 255 # Important: rescale pixels to [0, 1]
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    val.x <- mx.nd.array(val.x)
                                    
                                    # One-hot condition of the true digit, shaped 1x1x10xN.
                                    digit.real <- mx.nd.array(val[1,])
                                    digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
                                    digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))
                                    
                                    # Random digit condition for the generator's fake images.
                                    digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
                                    digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
                                    digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))
                                    
                                    # Standard-normal noise input for the generator.
                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    
                                    # FIX: build the constant label arrays at batch_size directly.
                                    # The original used runif(10, 0, 0) / rep(x, 10) — length-10
                                    # vectors silently recycled into a batch_size-long array.
                                    label.real <- mx.nd.array(array(0, dim = c(1, 1, 1, batch_size)))
                                    label.fake <- mx.nd.array(array(1, dim = c(1, 1, 1, batch_size)))
                                    label.gen <- mx.nd.array(array(0, dim = c(1, 1, 1, batch_size)))
                                    
                                    list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize = function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)

實作WGAN(2)

# Conditional generator graph: 1x1x10 noise plus 1x1x10 one-hot digit
# condition -> 28x28x1 image.
gen_data <- mx.symbol.Variable('data')
gen_digit <- mx.symbol.Variable('digit')

# Concatenate noise and condition along the channel axis (20 channels).
gen_concat <- mx.symbol.concat(data = list(gen_data, gen_digit), num.args = 2, dim = 1, name = "gen_concat")

# Transposed convolutions upsample 1x1 -> 4x4 -> 7x7 -> 14x14 -> 28x28.
gen_deconv1 <- mx.symbol.Deconvolution(data = gen_concat, kernel = c(4, 4), stride = c(2, 2), num_filter = 256, name = 'gen_deconv1')
gen_bn1 <- mx.symbol.BatchNorm(data = gen_deconv1, fix_gamma = TRUE, name = 'gen_bn1')
gen_relu1 <- mx.symbol.Activation(data = gen_bn1, act_type = "relu", name = 'gen_relu1')

gen_deconv2 <- mx.symbol.Deconvolution(data = gen_relu1, kernel = c(3, 3), stride = c(2, 2), pad = c(1, 1), num_filter = 128, name = 'gen_deconv2')
gen_bn2 <- mx.symbol.BatchNorm(data = gen_deconv2, fix_gamma = TRUE, name = 'gen_bn2')
gen_relu2 <- mx.symbol.Activation(data = gen_bn2, act_type = "relu", name = 'gen_relu2')

gen_deconv3 <- mx.symbol.Deconvolution(data = gen_relu2, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 64, name = 'gen_deconv3')
gen_bn3 <- mx.symbol.BatchNorm(data = gen_deconv3, fix_gamma = TRUE, name = 'gen_bn3')
gen_relu3 <- mx.symbol.Activation(data = gen_bn3, act_type = "relu", name = 'gen_relu3')

gen_deconv4 <- mx.symbol.Deconvolution(data = gen_relu3, kernel = c(4, 4), stride = c(2, 2), pad = c(1, 1), num_filter = 1, name = 'gen_deconv4')
# Sigmoid squashes the output into [0, 1], matching the rescaled pixel range.
gen_pred <- mx.symbol.Activation(data = gen_deconv4, act_type = "sigmoid", name = 'gen_pred')
# Conditional discriminator (WGAN critic) for 28x28x1 images.
dis_img <- mx.symbol.Variable('img')
dis_digit <- mx.symbol.Variable("digit")
dis_label <- mx.symbol.Variable('label')

# Inject the condition by broadcast-multiplying the image with the 1x1x10
# one-hot digit vector, producing a 10-channel conditioned input.
dis_concat <- mx.symbol.broadcast_mul(lhs = dis_img, rhs = dis_digit, name = 'dis_concat')

dis_conv1 <- mx.symbol.Convolution(data = dis_concat, kernel = c(3, 3), num_filter = 24, no.bias = TRUE, name = 'dis_conv1')
dis_bn1 <- mx.symbol.BatchNorm(data = dis_conv1, fix_gamma = TRUE, name = 'dis_bn1')
dis_relu1 <- mx.symbol.LeakyReLU(data = dis_bn1, act_type = "leaky", slope = 0.2, name = "dis_relu1")
dis_pool1 <- mx.symbol.Pooling(data = dis_relu1, pool_type = "avg", kernel = c(2, 2), stride = c(2, 2), name = 'dis_pool1')

dis_conv2 <- mx.symbol.Convolution(data = dis_pool1, kernel = c(3, 3), stride = c(2, 2), num_filter = 32, no.bias = TRUE, name = 'dis_conv2')
dis_bn2 <- mx.symbol.BatchNorm(data = dis_conv2, fix_gamma = TRUE, name = 'dis_bn2')
dis_relu2 <- mx.symbol.LeakyReLU(data = dis_bn2, act_type = "leaky", slope = 0.2, name = "dis_relu2")

dis_conv3 <- mx.symbol.Convolution(data = dis_relu2, kernel = c(3, 3), num_filter = 64, no.bias = TRUE, name = 'dis_conv3')
dis_bn3 <- mx.symbol.BatchNorm(data = dis_conv3, fix_gamma = TRUE, name = 'dis_bn3')
dis_relu3 <- mx.symbol.LeakyReLU(data = dis_bn3, act_type = "leaky", slope = 0.2, name = "dis_relu3")

dis_conv4 <- mx.symbol.Convolution(data = dis_relu3, kernel = c(4, 4), num_filter = 64, no.bias = TRUE, name = 'dis_conv4')
dis_bn4 <- mx.symbol.BatchNorm(data = dis_conv4, fix_gamma = TRUE, name = 'dis_bn4')
dis_relu4 <- mx.symbol.LeakyReLU(data = dis_bn4, act_type = "leaky", slope = 0.2, name = "dis_relu4")

# 1x1 convolution yields the raw critic score — no sigmoid (WGAN).
dis_pred <- mx.symbol.Convolution(data = dis_relu4, kernel = c(1, 1), num_filter = 1, name = 'dis_pred')
# Wasserstein loss: mean(score*(1-label) - score*label). With label = 1 the
# loss is -mean(score); with label = 0 it is +mean(score).
w_loss_pos <-  mx.symbol.broadcast_mul(dis_pred, dis_label)
w_loss_neg <-  mx.symbol.broadcast_mul(dis_pred, 1 - dis_label)
w_loss_mean <- mx.symbol.mean(w_loss_neg - w_loss_pos)
w_loss <- mx.symbol.MakeLoss(w_loss_mean, name = 'w_loss')

實作WGAN(3)

# Adam with beta1 = 0 (first-moment momentum disabled) and a small learning
# rate, one optimizer per network.
gen_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
dis_optimizer <- mx.opt.create(name = "adam", learning.rate = 1e-4, beta1 = 0, beta2 = 0.9, wd = 0)
# Executors bound to a fixed batch size of 32 on CPU; gradients are written
# on every backward pass.
gen_executor <- mx.simple.bind(symbol = gen_pred,
                               data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32),
                               ctx = mx.cpu(), grad.req = "write")

dis_executor <- mx.simple.bind(symbol = w_loss,
                               img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32),
                               ctx = mx.cpu(), grad.req = "write")
# Initial parameters (uniform initializer, scale 0.01), via mxnet internals.

mx.set.seed(0)

gen_arg <- mxnet:::mx.model.init.params(symbol = gen_pred,
                                        input.shape = list(data = c(1, 1, 10, 32), digit = c(1, 1, 10, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

dis_arg <- mxnet:::mx.model.init.params(symbol = w_loss,
                                        input.shape = list(img = c(28, 28, 1, 32), digit = c(1, 1, 10, 32), label = c(1, 1, 1, 32)),
                                        output.shape = NULL,
                                        initializer = mxnet:::mx.init.uniform(0.01),
                                        ctx = mx.cpu())

# Update parameters: copy initialized arg/aux arrays into both executors.

mx.exec.update.arg.arrays(gen_executor, gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(gen_executor, gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(dis_executor, dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(dis_executor, dis_arg$aux.params, match.name = TRUE)
# Stateful updater closures that apply each optimizer to its executor.
gen_updater <- mx.opt.get.updater(optimizer = gen_optimizer, weights = gen_executor$ref.arg.arrays)
dis_updater <- mx.opt.get.updater(optimizer = dis_optimizer, weights = dis_executor$ref.arg.arrays)

實作WGAN(4)

# Training loop: per batch, update the critic on a fake batch then a real
# batch, clip the critic weights (WGAN Lipschitz constraint), then update
# the generator through the critic's gradient w.r.t. its image input.
set.seed(0)
n.epoch <- 20
w_limit <- 0.1 # clipping bound for critic weights
logger <- list(gen_loss = NULL, dis_real_loss = NULL, dis_fake_loss = NULL)
for (j in 1:n.epoch) {
  
  current_batch <- 0
  my_iter$reset()
  
  while (my_iter$iter.next()) {
    
    my_values <- my_iter$value()
    
    # Generator (forward): synthesize fakes from noise + sampled conditions.
    
    mx.exec.update.arg.arrays(gen_executor, arg.arrays = list(data = my_values[['noise']], digit = my_values[['digit.fake']]), match.name = TRUE)
    mx.exec.forward(gen_executor, is.train = TRUE)
    gen_pred_output <- gen_executor$ref.outputs[[1]]
    
    # Discriminator (fake): critic update on generated images (label = 1).
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.fake']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_fake_loss <- c(logger$dis_fake_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Discriminator (real): critic update on real images (label = 0).
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = my_values[['img']], digit = my_values[['digit.real']], label = my_values[['label.real']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    dis_update_args <- dis_updater(weight = dis_executor$ref.arg.arrays, grad = dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(dis_executor, dis_update_args, skip.null = TRUE)
    
    logger$dis_real_loss <- c(logger$dis_real_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    # Weight clipping (only for discriminator): clamp every parameter whose
    # name contains 'weight' into [-w_limit, w_limit]. Biases and BatchNorm
    # gamma/beta are not matched by this grep and stay unclipped.
    
    dis_weight_names <- grep('weight', names(dis_executor$ref.arg.arrays), value = TRUE)
    
    
    
    for (k in dis_weight_names) {
      
      current_dis_weight <- dis_executor$ref.arg.arrays[[k]] %>% as.array()
      # min(w, w_limit) then max(w, -w_limit) == clamp to the interval.
      current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
        mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
        mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
        list()
      names(current_dis_weight_list) <- k
      mx.exec.update.arg.arrays(dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
      
    }
    
    # Generator (backward): rerun the critic on the fakes with the generator
    # target (label.gen = 0), take the gradient w.r.t. the 'img' input, and
    # backpropagate it through the generator before updating its weights.
    
    mx.exec.update.arg.arrays(dis_executor, arg.arrays = list(img = gen_pred_output, digit = my_values[['digit.fake']], label = my_values[['label.gen']]), match.name = TRUE)
    mx.exec.forward(dis_executor, is.train = TRUE)
    mx.exec.backward(dis_executor)
    img_grads <- dis_executor$ref.grad.arrays[['img']]
    mx.exec.backward(gen_executor, out_grads = img_grads)
    gen_update_args <- gen_updater(weight = gen_executor$ref.arg.arrays, grad = gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(gen_executor, gen_update_args, skip.null = TRUE)
    
    logger$gen_loss <- c(logger$gen_loss, as.array(dis_executor$ref.outputs[[1]]))
    
    if (current_batch %% 100 == 0) {
      
      # Show current images: decode the one-hot conditions of the first nine
      # samples and overlay each digit on its generated image.
      
      current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
      
      par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
      
      for (i in 1:9) {
        img <- as.array(gen_pred_output)[,,,i]
        plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
        rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
        text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
      }
      
      # Show loss (most recent value of each trace)
      
      message('Epoch [', j, '] Batch [', current_batch, '] Generator-loss = ', formatC(tail(logger$gen_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (real) = ', formatC(tail(logger$dis_real_loss, 1), digits = 5, format = 'f'))
      message('Epoch [', j, '] Batch [', current_batch, '] Discriminator-loss (fake) = ', formatC(tail(logger$dis_fake_loss, 1), digits = 5, format = 'f'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  # End-of-epoch snapshot: save a 3x3 sample sheet of the last batch as PDF.
  pdf(paste0('result/epoch_', j, '.pdf'), height = 6, width = 6)
  
  current_digits <- my_values[['digit.fake']] %>% as.array() %>% .[,,,1:9] %>% t %>% max.col - 1
  
  par(mfrow = c(3, 3), mar = c(0.1, 0.1, 0.1, 0.1))
  
  for (i in 1:9) {
    img <- as.array(gen_pred_output)[,,,i]
    plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
    rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
    text(0.05, 0.95, current_digits[i], col = 'green', cex = 2)
  }
  
  dev.off()
  
  # Package executor state as MXFeedForwardModel objects for mx.model.save.
  # NOTE(review): [-c(1:2)] drops the first two arg arrays, presumably the
  # data-input variables; the critic symbol has three inputs (img, digit,
  # label) — confirm the label array is not being saved as a parameter.
  gen_model <- list()
  gen_model$symbol <- gen_pred
  gen_model$arg.params <- gen_executor$ref.arg.arrays[-c(1:2)]
  gen_model$aux.params <- gen_executor$ref.aux.arrays
  class(gen_model) <- "MXFeedForwardModel"
  
  dis_model <- list()
  dis_model$symbol <- dis_pred
  dis_model$arg.params <- dis_executor$ref.arg.arrays[-c(1:2)]
  dis_model$aux.params <- dis_executor$ref.aux.arrays
  class(dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = gen_model, prefix = 'model/cwgen_v1', iteration = j)
  mx.model.save(model = dis_model, prefix = 'model/cwdis_v1', iteration = j)
  
}

F16_3

實作WGAN(5)

# Plot the three loss traces on a common y-range.
range_logger <- range(unlist(logger))

plot(logger$gen_loss, type = 'l', col = 'red', lwd = 0.5, ylim = range_logger, xlab = 'Batch', ylab = 'loss')
lines(seq_along(logger$dis_real_loss), logger$dis_real_loss, col = 'blue', lwd = 0.5)
lines(seq_along(logger$dis_fake_loss), logger$dis_fake_loss, col = 'darkgreen', lwd = 0.5)
legend('topright', c('Gen', 'Real', 'Fake'), col = c('red', 'blue', 'darkgreen'), lwd = 1)

cwgen_model <- mx.model.load('model/cwgen_v1', 0)

# Generate one image per requested digit with a saved conditional generator.
# `model`  : MXFeedForwardModel holding symbol + arg/aux parameters.
# `digits` : integer vector of digit conditions (default 0:9).
# Returns the generated images as a 28x28x1xN array.
my_predict <- function (model, digits = 0:9) {
  
  n_img <- length(digits)
  
  # Bind a fresh inference executor sized to the number of requested digits.
  pred_exec <- mx.simple.bind(symbol = model$symbol,
                              data = c(1, 1, 10, n_img), digit = c(1, 1, 10, n_img),
                              ctx = mx.cpu())
  
  # Load the trained parameters into the executor.
  mx.exec.update.arg.arrays(pred_exec, model$arg.params, match.name = TRUE)
  mx.exec.update.aux.arrays(pred_exec, model$aux.params, match.name = TRUE)
  
  # Standard-normal noise, shaped 1x1x10xN.
  noise_input <- mx.nd.array(array(rnorm(n_img * 10, mean = 0, sd = 1),
                                   dim = c(1, 1, 10, n_img)))
  
  # One-hot digit conditions, shaped 1x1x10xN.
  digit_input <- mx.nd.one.hot(indices = mx.nd.array(digits), depth = 10)
  digit_input <- mx.nd.reshape(data = digit_input, shape = c(1, 1, -1, n_img))
  
  # Forward pass in inference mode and return the images as a plain array.
  mx.exec.update.arg.arrays(pred_exec, arg.arrays = list(data = noise_input, digit = digit_input), match.name = TRUE)
  mx.exec.forward(pred_exec, is.train = FALSE)
  
  as.array(pred_exec$ref.outputs[[1]])
  
}
# Generate one image per digit 0-9 and show them in a 2x5 grid.
pred_img <- my_predict(model = cwgen_model, digits = 0:9)

par(mfrow = c(2, 5), mar = c(0.1, 0.1, 0.1, 0.1))

for (i in 1:10) {
  img <- pred_img[,,,i]
  plot(NA, xlim = 0:1, ylim = 0:1, xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(t(img)), -0.04, -0.04, 1.04, 1.04, interpolate = FALSE)
}

練習1:實作LSGAN

– 這一系列的GAN,說穿了就是改改損失函數,而LSGAN所使用的是平方誤差函數,損失函數被改為:

\[ \begin{align} loss(y, x) & = (x-y)^2 \end{align} \]

  1. \(a = -1\)、\(b = 1\)、\(c = 0\)

  2. \(a = 0\)、\(b = 1\)、\(c = 1\)

練習1答案

# LSGAN loss: mean squared difference between critic output and target label.
loss_diff <- mx.symbol.broadcast_minus(dis_pred, dis_label)
loss_square_diff <- mx.symbol.square(loss_diff)
loss_mean <- mx.symbol.mean(loss_square_diff)
ls_loss <- mx.symbol.MakeLoss(loss_mean, name = 'ls_loss')
# Custom iterator for conditional LSGAN training: same data handling as the
# WGAN iterator, but with squared-error targets real = 0, fake = 1 and
# generator target = 1 (answer 2 of the exercise).
my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "data.csv", "data.shape", "batch.size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, data.csv, data.shape, batch.size){
                                    # Delegate file reading / batching to the built-in CSV iterator.
                                    csv_iter <- mx.io.CSVIter(data.csv = data.csv, data.shape = data.shape, batch.size = batch.size)
                                    .self$iter <- csv_iter
                                    .self
                                  },
                                  value = function(){
                                    val <- as.array(.self$iter$value()$data)
                                    # Row 1 is the label; the rest are pixel values.
                                    val.x <- val[-1,]
                                    batch_size <- ncol(val.x)
                                    val.x <- val.x / 255 # Important: rescale pixels to [0, 1]
                                    dim(val.x) <- c(28, 28, 1, batch_size)
                                    val.x <- mx.nd.array(val.x)
                                    
                                    digit.real <- mx.nd.array(val[1,])
                                    digit.real <- mx.nd.one.hot(indices = digit.real, depth = 10)
                                    digit.real <- mx.nd.reshape(data = digit.real, shape = c(1, 1, -1, batch_size))
                                    
                                    digit.fake <- mx.nd.array(sample(0:9, size = batch_size, replace = TRUE))
                                    digit.fake <- mx.nd.one.hot(indices = digit.fake, depth = 10)
                                    digit.fake <- mx.nd.reshape(data = digit.fake, shape = c(1, 1, -1, batch_size))
                                    
                                    rand <- rnorm(batch_size * 10, mean = 0, sd = 1)
                                    rand <- array(rand, dim = c(1, 1, 10, batch_size))
                                    rand <- mx.nd.array(rand)
                                    
                                    # FIX: fill the constant label arrays at batch_size directly.
                                    # The original used rep(x, 10) — a length-10 vector silently
                                    # recycled into a batch_size-long array.
                                    label.real <- mx.nd.array(array(0, dim = c(1, 1, 1, batch_size)))
                                    label.fake <- mx.nd.array(array(1, dim = c(1, 1, 1, batch_size)))
                                    label.gen <- mx.nd.array(array(1, dim = c(1, 1, 1, batch_size)))
                                    
                                    list(noise = rand, img = val.x, digit.fake = digit.fake, digit.real = digit.real, label.fake = label.fake, label.real = label.real, label.gen = label.gen)
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize = function(){
                                  }
                                )
)

my_iter <- my_iterator_func(iter = NULL,  data.csv = 'data/train_data.csv', data.shape = 785, batch.size = 32)

循環對抗生成網路簡介(1)

– 這是Jun-Yan Zhu、Taesung Park、Phillip Isola與Alexei A. Efros在2017年提出的研究:Unpaired Image-to-Image Translation using Cycle-Consistent Adversarial Networks中所提出的模型

F16_5

– 但比較厲害的是,Cycle GAN所輸入的條件並非是類似於Conditional GAN的結構化條件,而是直接利用一張圖片當作Condition。

循環對抗生成網路簡介(2)

F16_6

– 在許多領域中,我們不可能蒐集到Paired data,舉例來說像是照片與藝術畫的轉換,因此這件事情是非常重要的!除此之外,Paired data的蒐集難度也高上非常多,這也是Cycle GAN所吸引人的地方所在。

F16_4

循環對抗生成網路簡介(3)

– 假定有兩個函數分別是\(G(X)\)負責將\(X\)轉換為\(\hat{Y}\),而另一個函數\(F(Y)\)負責將\(Y\)轉換為\(\hat{X}\),則當兩個函數達到完美狀態時,必須保證\(F(G(X)) = X\)且\(G(F(Y)) = Y\)。

F16_7

\[ \begin{align} \mbox{cycle consistency loss} & = |F(G(X)) - X| + |G(F(Y)) - Y| \end{align} \]

註:原始論文使用的是L1 loss(如上式),而替換成L2 loss或者是其他損失函數影響不大。

循環對抗生成網路簡介(4)

– 這時候我們要引入兩個Discriminator分別是\(D_x(X)\)以及\(D_y(Y)\),他們分別要識別翻譯出來的圖是真的還是假的,而在Generator跟Discriminator的競合我們將給出一個對抗損失(adversarial loss),這與之前的所有GAN完全一樣,我們當然也能用WGAN的損失函數:

\[ \begin{align} \mbox{adversarial loss for Discriminator(x)} & = D_x(F(Y)) - D_x(X) \\ \mbox{adversarial loss for Discriminator(y)} & = D_y(G(X)) - D_y(Y) \\ \mbox{adversarial loss for Generator(x)} & = - D_x(F(Y)) \\ \mbox{adversarial loss for Generator(y)} & = - D_y(G(X)) \end{align} \]

F16_8

循環對抗生成網路簡介(5)

F16_9

循環對抗生成網路簡介(6)

– 這是原始照片:

library(OpenImageR)
library(jpeg)

# Load the demo photograph and resize it with bilinear interpolation.
photo <- readJPEG('images/header.jpg')
resize_photo <- resizeImage(image = photo,
                            width = 648,
                            height = 256,
                            method = "bilinear")

# Display an image array filling the whole device: zero margins, no axes,
# and an inverted y-axis so row 1 of the raster appears at the top.
Show_img <- function (img) {
  par(mai = c(0, 0, 0, 0))
  plot(NA, xlim = c(0.04, 0.96), ylim = c(0.96, 0.04),
       xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(img), xleft = 0, ybottom = 1, xright = 1, ytop = 0,
              interpolate = FALSE)
}

# Render the resized photograph.
Show_img(resize_photo)

library(mxnet)

P2M_gen_model <- mx.model.load(prefix = 'model/P2M_gen_v1', iteration = 0)

# Photo-to-Monet inference helper.
# NOTE(review): this redefines the earlier `my_predict` (conditional-GAN
# version); from here on only the CycleGAN variant is available.
# `model` : MXFeedForwardModel whose symbol takes a 'P2M_img' input.
# `img`   : (H, W, C) image array; returns the translated (H, W, C) array.
my_predict <- function (model, img) {
  
  # Append a batch dimension of 1: (H, W, C) -> (H, W, C, 1).
  dim(img) <- c(dim(img), 1)
  
  P2M_executor <- mx.simple.bind(symbol = model$symbol,
                                 P2M_img = dim(img),
                                 ctx = mx.cpu())
  
  # Load trained parameters, then feed the image and run inference.
  mx.exec.update.arg.arrays(P2M_executor, model$arg.params, match.name = TRUE)
  mx.exec.update.aux.arrays(P2M_executor, model$aux.params, match.name = TRUE)
  
  mx.exec.update.arg.arrays(P2M_executor, arg.arrays = list(P2M_img = mx.nd.array(img)), match.name = TRUE)
  mx.exec.forward(P2M_executor, is.train = FALSE)
  P2M_pred_output <- P2M_executor$ref.outputs[[1]]
  
  # Drop the batch dimension before returning.
  return(as.array(P2M_pred_output)[,,,1])
  
}

# Translate the photo and display the result.
monet_img <- my_predict(model = P2M_gen_model, img = resize_photo)

Show_img(monet_img)

循環對抗生成網路簡介(7)

– 為了解決這個問題,我們還需要額外發展出一個identity mapping loss來解決這個問題!

– 我們依此定義identity mapping loss如下(這邊要注意的是,函數\(G\)原來是負責\(X \rightarrow Y\)的,而函數\(F\)原來則是負責\(Y \rightarrow X\)):

\[ \begin{align} \mbox{identity mapping loss} & = |F(X) - X| + |G(Y) - Y| \end{align} \]

F16_10

循環對抗生成網路簡介(8)

F16_11

# Load the second generator checkpoint and translate the photo again.
P2M_gen_model <- mx.model.load(prefix = 'model/P2M_gen_v2', iteration = 0)

monet_img <- my_predict(model = P2M_gen_model, img = resize_photo)

Show_img(monet_img)

循環對抗生成網路簡介(9)

– 讓我們看看效果吧!這是再過一次的效果(共計2次):

# Feed the translated image back through the generator (second pass).
monet_img <- my_predict(model = P2M_gen_model, img = monet_img)

Show_img(monet_img)

– 再過一次試試看(共計3次):

# Third pass through the generator.
monet_img <- my_predict(model = P2M_gen_model, img = monet_img)

Show_img(monet_img)

循環對抗生成網路實作(1)

F16_12

– 其實不用特別看大概也知道,由於Input size等於Output size,所以它的結構會與先前做Segmentation時的網路類似。

F16_13

循環對抗生成網路實作(2)

# Libraries

library(mxnet)
library(imager)
library(jpeg)
library(magrittr)

# Parameters

CTX <- mx.cpu()                      # compute context
Batch_size <- 1                      # images per iterator batch
num_show_img <- 1                    # images per preview (presumably used by the training loop, not shown here)
n.epoch <- 10                        # number of training epochs
n.print <- 20                        # presumably batches between progress prints (not used in visible code)
w_limit <- 0.1                       # WGAN weight-clipping bound
learning_rate <- 1e-4                # optimizer learning rate
lambda_cycle_consistency_loss <- 10  # weight of the cycle-consistency term
lambda_identity_mapping_loss <- 5    # weight of the identity-mapping term
model_name <- 'mini'                 # prefix/tag for saved models

循環對抗生成網路實作(3)

# Load data

load('data/mini_train_list.RData')

# Iterator function

# Core of the unpaired-image iterator: each step yields a batch of Monet
# images and a batch of photos (64x64x3 each) read from disk.
my_iterator_core <- function (batch_size) {
  
  batch <-  0
  # Number of full batches per epoch, driven by the photo-list length.
  batch_per_epoch <- floor(length(train_list[[2]])/batch_size)
  
  reset <- function() {batch <<- 0}
  
  iter.next <- function() {
    
    batch <<- batch + 1
    if (batch > batch_per_epoch) {return(FALSE)} else {return(TRUE)}
    
  }
  
  value <- function() {
    
    # Indices of the current batch (sequential order, no shuffling).
    idx <- 1:batch_size + (batch - 1) * batch_size
    
    img_array.1 <- array(0, dim = c(64, 64, 3, batch_size))
    img_array.2 <- array(0, dim = c(64, 64, 3, batch_size))
    
    for (i in 1:batch_size) {
      
      img_array.1[,,,i] <- readJPEG(train_list[[1]][[idx[i]]])
      img_array.2[,,,i] <- readJPEG(train_list[[2]][[idx[i]]])
      
    }
    
    img_array.1 <- mx.nd.array(img_array.1)
    img_array.2 <- mx.nd.array(img_array.2)
    
    return(list(monet = img_array.1, photo = img_array.2))
    
  }
  
  # NOTE(review): `batch` in this list is a one-time snapshot (always 0);
  # the live counter lives in the closure environment.
  return(list(reset = reset, iter.next = iter.next, value = value, batch_size = batch_size, batch = batch))
  
}

# RefClass wrapper exposing the closure-based iterator through the standard
# mxnet iterator interface (value / iter.next / reset).
my_iterator_func <- setRefClass("Custom_Iter",
                                fields = c("iter", "batch_size"),
                                contains = "Rcpp_MXArrayDataIter",
                                methods = list(
                                  initialize = function(iter, batch_size = 16){
                                    .self$iter <- my_iterator_core(batch_size = batch_size)
                                    .self
                                  },
                                  value = function(){
                                    .self$iter$value()
                                  },
                                  iter.next = function(){
                                    .self$iter$iter.next()
                                  },
                                  reset = function(){
                                    .self$iter$reset()
                                  },
                                  finalize=function(){
                                  }
                                )
)

# Build an iterator

# Build an iterator with the configured batch size.
my_iter <- my_iterator_func(iter = NULL, batch_size = Batch_size)
# Show image function (no par() inside — callers set the margins).

Show_img <- function (img) {
  plot(NA, xlim = c(0.04, 0.96), ylim = c(0.96, 0.04), xaxt = "n", yaxt = "n", bty = "n")
  rasterImage(as.raster(img), 0, 1, 1, 0, interpolate = FALSE)
}

# Test the iterator: advance one step and display the first Monet image.

my_iter$reset()
my_iter$iter.next()
## [1] TRUE
test_data <- my_iter$value()
par(mai = rep(0, 4))
Show_img(as.array(test_data[[1]])[,,,1])

循環對抗生成網路實作(4)

# Residual block: two (Conv -> InstanceNorm -> LeakyReLU) stages at stride 1,
# with the block input added back onto the transformed path. Spatial size is
# preserved, so `num_filters` must match the input's channel count.
Residual.CONV_Module <- function (indata, num_filters = 128, kernel_size = 3, relu_slope = 0, name = 'g1', stage = 1) {
  
  # Symmetric padding that keeps the spatial size for odd kernels.
  pad_size <- (kernel_size - 1)/2
  prefix <- paste0(name, '_', stage)
  
  # First conv stage.
  conv_a <- mx.symbol.Convolution(data = indata, kernel = c(kernel_size, kernel_size), stride = c(1, 1),
                                  pad = c(pad_size, pad_size),
                                  no.bias = TRUE, num.filter = num_filters,
                                  name = paste0(prefix, '_Conv.1'))
  norm_a <- mx.symbol.InstanceNorm(data = conv_a, name = paste0(prefix, '_InstNorm.1'))
  act_a <- mx.symbol.LeakyReLU(data = norm_a, act.type = 'leaky', slope = relu_slope, name = paste0(prefix, '_ReLU.1'))
  
  # Second conv stage.
  conv_b <- mx.symbol.Convolution(data = act_a, kernel = c(kernel_size, kernel_size), stride = c(1, 1),
                                  pad = c(pad_size, pad_size),
                                  no.bias = TRUE, num.filter = num_filters,
                                  name = paste0(prefix, '_Conv.2'))
  norm_b <- mx.symbol.InstanceNorm(data = conv_b, name = paste0(prefix, '_InstNorm.2'))
  act_b <- mx.symbol.LeakyReLU(data = norm_b, act.type = 'leaky', slope = relu_slope, name = paste0(prefix, '_ReLU.2'))
  
  # Residual connection.
  mx.symbol.broadcast_plus(lhs = indata, rhs = act_b, name = paste0(prefix, '_ResBlock'))
  
}

# General conv block with input dropout. With `normalization = TRUE`:
# Conv (no bias) -> InstanceNorm -> LeakyReLU. With FALSE: a single biased
# Conv and no activation (used as the final projection layer).
general.CONV_Module <- function (indata, num_filters = 128, kernel_size = 3, stride = 1, pad = 1, relu_slope = 0, drop_p = 0, name = 'g1', stage = 1, normalization = FALSE) {
  
  Drop <- mx.symbol.Dropout(data = indata, p = drop_p, name = paste0(name, '_', stage, '_Drop'))
  
  if (normalization) {
    
    # Bias is omitted because InstanceNorm re-centers the activations anyway.
    Conv <- mx.symbol.Convolution(data = Drop, kernel = c(kernel_size, kernel_size), stride = c(stride, stride),
                                  pad = c(pad, pad),
                                  no.bias = TRUE, num.filter = num_filters,
                                  name = paste0(name, '_', stage, '_Conv'))
    InstNorm <- mx.symbol.InstanceNorm(data = Conv, name = paste0(name, '_', stage, '_InstNorm'))
    ReLU <- mx.symbol.LeakyReLU(data = InstNorm, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU'))
    
    return(ReLU)
    
  } else {
    
    # Plain convolution with bias; no normalization or activation.
    Conv <- mx.symbol.Convolution(data = Drop, kernel = c(kernel_size, kernel_size), stride = c(stride, stride),
                                  pad = c(pad, pad),
                                  no.bias = FALSE, num.filter = num_filters,
                                  name = paste0(name, '_', stage, '_Conv'))
    
    return(Conv)
    
  }
  
}

# 2x upsampling block: Deconvolution -> InstanceNorm -> LeakyReLU. When
# `updata` (an encoder skip connection) is supplied, the upsampled result is
# concatenated with it along the channel axis (U-Net style).
DECONV_Module <- function (indata, updata = NULL, num_filters = 128, relu_slope = 0, name = 'g1', stage = 1) {
  
  DeConv <- mx.symbol.Deconvolution(data = indata, kernel = c(2, 2), stride = c(2, 2),
                                    num_filter = num_filters,
                                    name = paste0(name, '_', stage, '_DeConv'))
  
  InstNorm <- mx.symbol.InstanceNorm(data = DeConv, name = paste0(name, '_', stage, '_InstNorm'))
  ReLU <- mx.symbol.LeakyReLU(data = InstNorm, act.type = 'leaky', slope = relu_slope, name = paste0(name, '_', stage, '_ReLU'))
  
  if (is.null(updata)) {
    return(ReLU)
  } else {
    DenBlock <- mx.symbol.concat(data = list(updata, ReLU), num.args = 2, dim = 1, name = paste0(name, '_', stage, '_DenBlock'))
    return(DenBlock)
  }
  
}

循環對抗生成網路實作(5)

# CycleGAN generator: two stride-2 downsampling conv blocks, two residual
# blocks, then two 2x deconv blocks with skip concatenations, finishing with
# a 7x7 conv to 3 channels and a sigmoid output in [0, 1].
Generator_symbol <- function (name = 'g1') {
  
  g_img <- mx.symbol.Variable(paste0(name, '_img'))
  g_1 <- general.CONV_Module(indata = g_img, num_filters = 8, kernel_size = 7, stride = 1, pad = 3, relu_slope = 0, drop_p = 0, name = name, stage = 1, normalization = TRUE)
  g_2 <- general.CONV_Module(indata = g_1, num_filters = 16, kernel_size = 3, stride = 2, pad = 1, relu_slope = 0, drop_p = 0, name = name, stage = 2, normalization = TRUE)
  g_3 <- general.CONV_Module(indata = g_2, num_filters = 32, kernel_size = 3, stride = 2, pad = 1, relu_slope = 0, drop_p = 0, name = name, stage = 3, normalization = TRUE)
  g_4 <- Residual.CONV_Module(indata = g_3, num_filters = 32, kernel_size = 3, relu_slope = 0, name = name, stage = 4)
  g_5 <- Residual.CONV_Module(indata = g_4, num_filters = 32, kernel_size = 3, relu_slope = 0, name = name, stage = 5)
  # Decoder: upsample and concatenate the matching encoder stage.
  g_6 <- DECONV_Module(indata = g_5, updata = g_2, num_filters = 16, relu_slope = 0, name = name, stage = 6)
  g_7 <- DECONV_Module(indata = g_6, updata = g_1, num_filters = 8, relu_slope = 0, name = name, stage = 7)
  g_8 <- general.CONV_Module(indata = g_7, num_filters = 3, kernel_size = 7, stride = 1, pad = 3, relu_slope = 0, drop_p = 0, name = name, stage = 8, normalization = FALSE)
  g_pred <- mx.symbol.Activation(data = g_8, act_type = "sigmoid", name = paste0(name, '_pred'))
  
  return(g_pred)
  
}
# CycleGAN critic: four stride-2 conv blocks (LeakyReLU slope 0.2, optional
# dropout), a 1x1 conv to one channel, then a mean over the remaining axes
# so each image yields a single scalar score (no sigmoid — WGAN-style).
Discriminator_symbol <- function (name = 'd1', drop_p = 0) {
  
  d_img <- mx.symbol.Variable(paste0(name, '_img'))
  d_1 <- general.CONV_Module(indata = d_img, num_filters = 8, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 1, normalization = TRUE)
  d_2 <- general.CONV_Module(indata = d_1, num_filters = 16, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 2, normalization = TRUE)
  d_3 <- general.CONV_Module(indata = d_2, num_filters = 32, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 3, normalization = TRUE)
  d_4 <- general.CONV_Module(indata = d_3, num_filters = 64, kernel_size = 4, stride = 2, pad = 0, relu_slope = 0.2, drop_p = drop_p, name = name, stage = 4, normalization = TRUE)
  d_5 <- general.CONV_Module(indata = d_4, num_filters = 1, kernel_size = 1, stride = 1, pad = 0, relu_slope = 0, drop_p = drop_p, name = name, stage = 5, normalization = FALSE)
  # Average the 1-channel score map into one value per example.
  d_pred <- mx.symbol.mean(data = d_5, axis = 1:3, keepdims = FALSE, name = paste0(name, '_pred'))
  
  return(d_pred)
  
}
adversarial_loss <- function (pred, label, lambda = 1) {
  
  # WGAN-style adversarial loss: loss(y, x) = (1 - y) * x - y * x, averaged
  # over the batch. A label of 0 contributes +pred and a label of 1
  # contributes -pred, so minimizing pushes scores of label-0 samples down
  # and label-1 samples up (lower score = "more real" by convention here).
  #
  # pred:   critic score symbol (one scalar per image).
  # label:  0/1 label symbol, broadcast against pred.
  # lambda: scalar weight applied to the mean loss.
  # Returns: a MakeLoss symbol wrapping the weighted mean.
  pos_part <- mx.symbol.broadcast_mul(pred, label)
  neg_part <- mx.symbol.broadcast_mul(pred, 1 - label)
  raw_mean <- mx.symbol.mean(neg_part - pos_part)
  mx.symbol.MakeLoss(raw_mean * lambda)
  
}

cycle_consistency_loss <- function (pred, label, lambda = 10) {
  
  # L1 cycle-consistency term of CycleGAN: the mean absolute difference
  # between the reconstructed image (pred) and the original image (label),
  # averaged over all four axes and scaled by lambda.
  #
  # pred:   reconstructed-image symbol.
  # label:  original-image symbol of the same shape.
  # lambda: scalar weight of this term (default 10).
  # Returns: a MakeLoss symbol wrapping the weighted mean L1 distance.
  l1_map <- mx.symbol.abs(data = mx.symbol.broadcast_minus(lhs = pred, rhs = label))
  l1_mean <- mx.symbol.mean(data = l1_map, axis = 0:3, keepdims = FALSE)
  mx.symbol.MakeLoss(l1_mean * lambda)
  
}

identity_mapping_loss <- function (pred, label, lambda = 5) {
  
  # L1 identity-mapping term of CycleGAN: the mean absolute difference between
  # the generator's output on an image already in its target domain (pred) and
  # that image itself (label), averaged over all four axes and scaled by lambda.
  # Structure intentionally mirrors cycle_consistency_loss.
  #
  # pred:   generator-output symbol.
  # label:  the unmodified input-image symbol of the same shape.
  # lambda: scalar weight of this term (default 5).
  # Returns: a MakeLoss symbol wrapping the weighted mean L1 distance.
  #
  # Fix: the local result variable was copy-pasted as `cycle_consistency_loss`,
  # which was misleading; it is renamed here. Behavior is unchanged.
  diff_pred_label <- mx.symbol.broadcast_minus(lhs = pred, rhs = label)
  abs_diff_pred_label <- mx.symbol.abs(data = diff_pred_label)
  mean_loss <- mx.symbol.mean(data = abs_diff_pred_label, axis = 0:3, keepdims = FALSE)
  weighted_mean_loss <- mean_loss * lambda
  identity_mapping_loss <- mx.symbol.MakeLoss(weighted_mean_loss)
  
  return(identity_mapping_loss)
  
}

循環對抗生成網路實作(6)

# Build the four network symbols and the three loss graphs. The two
# generators share one architecture (Generator_symbol) and the two
# discriminators share another (Discriminator_symbol); name prefixes keep
# their parameters distinct.

# Generator-1 (Monet to Photo)

M2P_gen <- Generator_symbol(name = 'M2P')

# Generator-2 (Photo to Monet)

P2M_gen <- Generator_symbol(name = 'P2M')

# Discriminator-1 (Monet)

Monet_dis <- Discriminator_symbol(name = 'Monet', drop_p = 0)

# Discriminator-2 (Photo)

Photo_dis <- Discriminator_symbol(name = 'Photo', drop_p = 0)

# adversarial loss-1 (Monet)
# NOTE(review): both adversarial losses bind a variable literally named
# 'label'; reassigning the R variable is harmless since mxnet identifies
# symbol variables by their string name.

label <- mx.symbol.Variable('label')
Monet_loss <- adversarial_loss(pred = Monet_dis, label = label, lambda = 1)

# adversarial loss-2 (Photo)

label <- mx.symbol.Variable('label')
Photo_loss <- adversarial_loss(pred = Photo_dis, label = label, lambda = 1)

# cycle consistency loss
# NOTE(review): lambda_cycle_consistency_loss and
# lambda_identity_mapping_loss must already exist in the workspace — they
# are defined outside this chunk.

pred <- mx.symbol.Variable('pred')
label <- mx.symbol.Variable('label')
CC_loss <- cycle_consistency_loss(pred = pred, label = label, lambda = lambda_cycle_consistency_loss)

# identity mapping loss

pred <- mx.symbol.Variable('pred')
label <- mx.symbol.Variable('label')
IM_loss <- identity_mapping_loss(pred = pred, label = label, lambda = lambda_identity_mapping_loss)

循環對抗生成網路實作(7)

# Bind one executor per symbol graph with fixed input shapes, then randomly
# initialize and copy parameters in. Shapes are given as
# c(64, 64, 3, Batch_size) — presumably (W, H, C, N) in the R binding's
# column-major convention; TODO confirm. grad.req = "write" makes the
# executors compute gradients for every argument, including the image
# inputs, which the training loop reads to chain gradients between networks.

M2P_gen_executor <- mx.simple.bind(symbol = M2P_gen,
                                   M2P_img = c(64, 64, 3, Batch_size),
                                   ctx = CTX, grad.req = "write")

P2M_gen_executor <- mx.simple.bind(symbol = P2M_gen,
                                   P2M_img = c(64, 64, 3, Batch_size),
                                   ctx = CTX, grad.req = "write")

Monet_dis_executor <- mx.simple.bind(symbol = Monet_loss,
                                     Monet_img = c(64, 64, 3, Batch_size), label = c(Batch_size),
                                     ctx = CTX, grad.req = "write")

Photo_dis_executor <- mx.simple.bind(symbol = Photo_loss,
                                     Photo_img = c(64, 64, 3, Batch_size), label = c(Batch_size),
                                     ctx = CTX, grad.req = "write")

# The two reconstruction-loss executors compare two full-size images, so both
# 'pred' and 'label' take image-shaped arrays.

cycle_consistency_executor <- mx.simple.bind(symbol = CC_loss,
                                             pred = c(64, 64, 3, Batch_size), label = c(64, 64, 3, Batch_size),
                                             ctx = CTX, grad.req = "write")

identity_mapping_executor <- mx.simple.bind(symbol = IM_loss,
                                            pred = c(64, 64, 3, Batch_size), label = c(64, 64, 3, Batch_size),
                                            ctx = CTX, grad.req = "write")
# Initial parameters
# Weights are drawn from N(0, 0.02) via internal (unexported) mxnet helpers.

mx.set.seed(0)

M2P_gen_arg <- mxnet:::mx.model.init.params(symbol = M2P_gen,
                                            input.shape = list(M2P_img = c(64, 64, 3, Batch_size)),
                                            output.shape = NULL,
                                            initializer = mxnet:::mx.init.normal(0.02),
                                            ctx = CTX)

P2M_gen_arg <- mxnet:::mx.model.init.params(symbol = P2M_gen,
                                            input.shape = list(P2M_img = c(64, 64, 3, Batch_size)),
                                            output.shape = NULL,
                                            initializer = mxnet:::mx.init.normal(0.02),
                                            ctx = CTX)

Monet_dis_arg <- mxnet:::mx.model.init.params(symbol = Monet_loss,
                                              input.shape = list(Monet_img = c(64, 64, 3, Batch_size), label = c(Batch_size)),
                                              output.shape = NULL,
                                              initializer = mxnet:::mx.init.normal(0.02),
                                              ctx = CTX)

Photo_dis_arg <- mxnet:::mx.model.init.params(symbol = Photo_loss,
                                              input.shape = list(Photo_img = c(64, 64, 3, Batch_size), label = c(Batch_size)),
                                              output.shape = NULL,
                                              initializer = mxnet:::mx.init.normal(0.02),
                                              ctx = CTX)

# Update parameters
# Copy the freshly initialized arg/aux arrays into each bound executor.

mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(M2P_gen_executor, M2P_gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(P2M_gen_executor, P2M_gen_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(Monet_dis_executor, Monet_dis_arg$aux.params, match.name = TRUE)
mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_arg$arg.params, match.name = TRUE)
mx.exec.update.aux.arrays(Photo_dis_executor, Photo_dis_arg$aux.params, match.name = TRUE)

循環對抗生成網路實作(8)

# Optimizers
# One Adam optimizer per network (beta1 = 0, beta2 = 0.9, no weight decay).
# learning_rate must be defined earlier in the document.

M2P_gen_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
P2M_gen_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)

Monet_dis_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)
Photo_dis_optimizer <- mx.opt.create(name = "adam", learning.rate = learning_rate, beta1 = 0, beta2 = 0.9, wd = 0)

# Updaters
# Each updater closes over its executor's weight arrays and, given gradients,
# returns the updated argument arrays to write back.

M2P_gen_updater <- mx.opt.get.updater(optimizer = M2P_gen_optimizer, weights = M2P_gen_executor$ref.arg.arrays)
P2M_gen_updater <- mx.opt.get.updater(optimizer = P2M_gen_optimizer, weights = P2M_gen_executor$ref.arg.arrays)
Monet_dis_updater <- mx.opt.get.updater(optimizer = Monet_dis_optimizer, weights = Monet_dis_executor$ref.arg.arrays)
Photo_dis_updater <- mx.opt.get.updater(optimizer = Photo_dis_optimizer, weights = Photo_dis_executor$ref.arg.arrays)

循環對抗生成網路實作(9)

# Start to train
#
# Per batch the loop runs, in order: cycle-consistency updates (both
# directions), identity-mapping updates, and adversarial updates with weight
# clipping on the discriminators. NOTE(review): each loss part applies its
# own optimizer step, so every generator is updated several times per batch.

for (j in 1:n.epoch) {
  
  current_batch <- 0
  t0 <- Sys.time()
  my_iter$reset()
  
  # Per-epoch accumulators; each entry collects one value per batch and is
  # averaged when progress is printed.
  batch_logger <- list(Monet_adversarial_loss.gen = NULL,
                       Monet_adversarial_loss.fake = NULL,
                       Monet_adversarial_loss.real = NULL,
                       Photo_adversarial_loss.gen = NULL,
                       Photo_adversarial_loss.fake = NULL,
                       Photo_adversarial_loss.real = NULL,
                       Monet_cycle_consistency_loss = NULL,
                       Photo_cycle_consistency_loss = NULL,
                       Monet_identity_mapping_loss = NULL,
                       Photo_identity_mapping_loss = NULL)
  
  while (my_iter$iter.next()) {
    
    # my_values is expected to hold 'monet' and 'photo' image batches —
    # presumably produced by the custom iterator; confirm against its value().
    my_values <- my_iter$value()
    
    ##################################
    #                                #
    # Cycle consistency loss (Part1) #
    #                                #
    ##################################
    
    # Generator-1 forward (real Monet to fake Photo)
    
    mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['monet']]), match.name = TRUE)
    mx.exec.forward(M2P_gen_executor, is.train = TRUE)
    fake.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
    # Kept as an R array for plotting in the progress display below.
    fake.Photo_img <- as.array(fake.Photo_output)
    
    # Generator-2 forward (fake Photo to restored Monet)
    
    mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = fake.Photo_output), match.name = TRUE)
    mx.exec.forward(P2M_gen_executor, is.train = TRUE)
    restored.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
    restored.Monet_img <- as.array(restored.Monet_output)
    
    # Cycle consistency loss (Monet): compare the round-tripped image
    # against the original Monet batch.
    
    mx.exec.update.arg.arrays(cycle_consistency_executor, arg.arrays = list(pred = restored.Monet_output, label = my_values[['monet']]), match.name = TRUE)
    mx.exec.forward(cycle_consistency_executor, is.train = TRUE)
    mx.exec.backward(cycle_consistency_executor)
    
    batch_logger$Monet_cycle_consistency_loss <- c(batch_logger$Monet_cycle_consistency_loss, as.array(cycle_consistency_executor$ref.outputs[[1]]))
    
    # Generator-2 backward: feed the loss gradient w.r.t. 'pred' in as the
    # output gradient, then apply one optimizer step.
    
    P2M_grads <- cycle_consistency_executor$ref.grad.arrays[['pred']]
    mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
    P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
    
    # Generator-1 backward: chain the gradient through Generator-2's input
    # image back into Generator-1, then apply one optimizer step.
    
    M2P_grads <- P2M_gen_executor$ref.grad.arrays[['P2M_img']]
    mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
    M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
    
    #################################
    #                               #
    # Identity mapping loss (Part1) #
    #                               #
    #################################
    
    # Generator-1 forward (real Photo to fake Photo): M2P applied to an image
    # already in its target (Photo) domain should change it as little as
    # possible.
    
    mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['photo']]), match.name = TRUE)
    mx.exec.forward(M2P_gen_executor, is.train = TRUE)
    mirror.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
    mirror.Photo_img <- as.array(mirror.Photo_output)
    
    # Identity mapping loss (Photo)
    
    mx.exec.update.arg.arrays(identity_mapping_executor, arg.arrays = list(pred = mirror.Photo_output, label = my_values[['photo']]), match.name = TRUE)
    mx.exec.forward(identity_mapping_executor, is.train = TRUE)
    mx.exec.backward(identity_mapping_executor)
    
    batch_logger$Photo_identity_mapping_loss <- c(batch_logger$Photo_identity_mapping_loss, as.array(identity_mapping_executor$ref.outputs[[1]]))
    
    # Generator-1 backward (identity-mapping gradient, one optimizer step)
    
    M2P_grads <- identity_mapping_executor$ref.grad.arrays[['pred']]
    mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
    M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
    
    ############################
    #                          #
    # Adversarial loss (Part1) #
    #                          #
    ############################
    
    # Label convention (see adversarial_loss): label 1 marks fakes, label 0
    # marks reals, and a lower critic score means "more real".
    
    # Generator-1 forward (real Monet to fake Photo)
    
    mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = my_values[['monet']]), match.name = TRUE)
    mx.exec.forward(M2P_gen_executor, is.train = TRUE)
    fake.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
    
    # Discriminator-2 fake (Photo): train the critic on generated images
    # labelled 1.
    
    mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = fake.Photo_output, label = mx.nd.array(rep(1, Batch_size))), match.name = TRUE)
    mx.exec.forward(Photo_dis_executor, is.train = TRUE)
    mx.exec.backward(Photo_dis_executor)
    Photo_dis_update_args <- Photo_dis_updater(weight = Photo_dis_executor$ref.arg.arrays, grad = Photo_dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_update_args, skip.null = TRUE)
    
    batch_logger$Photo_adversarial_loss.fake <- c(batch_logger$Photo_adversarial_loss.fake, as.array(Photo_dis_executor$ref.outputs[[1]]))
    
    # Discriminator-2 real (Photo): train the critic on real images
    # labelled 0.
    
    mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = my_values[['photo']], label = mx.nd.array(0, Batch_size))), match.name = TRUE)
    mx.exec.forward(Photo_dis_executor, is.train = TRUE)
    mx.exec.backward(Photo_dis_executor)
    Photo_dis_update_args <- Photo_dis_updater(weight = Photo_dis_executor$ref.arg.arrays, grad = Photo_dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(Photo_dis_executor, Photo_dis_update_args, skip.null = TRUE)
    
    batch_logger$Photo_adversarial_loss.real <- c(batch_logger$Photo_adversarial_loss.real, as.array(Photo_dis_executor$ref.outputs[[1]]))
    
    # Adversarial loss (Photo): run fakes through the critic with label 0
    # (the "real" label) so the input gradient tells the generator how to
    # look more real. Only backward is run here — the critic is not updated.
    
    mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = list(Photo_img = fake.Photo_output, label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
    mx.exec.forward(Photo_dis_executor, is.train = TRUE)
    mx.exec.backward(Photo_dis_executor)
    
    batch_logger$Photo_adversarial_loss.gen <- c(batch_logger$Photo_adversarial_loss.gen, as.array(Photo_dis_executor$ref.outputs[[1]]))
    
    # Generator-1 backward (adversarial gradient via the critic's image input)
    
    M2P_grads <- Photo_dis_executor$ref.grad.arrays[['Photo_img']]
    mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
    M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
    
    # Weight clipping (Discriminator-2): clamp every *weight* array into
    # [-w_limit, w_limit] (WGAN weight clipping); biases are left untouched.
    
    if (!is.null(w_limit)) {
      
      dis_weight_names <- grep('weight', names(Photo_dis_executor$ref.arg.arrays), value = TRUE)
      
      for (k in dis_weight_names) {
        
        current_dis_weight <- Photo_dis_executor$ref.arg.arrays[[k]] %>% as.array()
        current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
          mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
          mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
          list()
        names(current_dis_weight_list) <- k
        mx.exec.update.arg.arrays(Photo_dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
        
      }
      
    }
    
    ##################################
    #                                #
    # Cycle consistency loss (Part2) #
    #                                #
    ##################################
    
    # Mirror image of Part1 with the domains swapped: Photo -> fake Monet ->
    # restored Photo.
    
    # Generator-2 forward (real Photo to fake Monet)
    
    mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['photo']]), match.name = TRUE)
    mx.exec.forward(P2M_gen_executor, is.train = TRUE)
    fake.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
    fake.Monet_img <- as.array(fake.Monet_output)
    
    # Generator-1 forward (fake Monet to restored Photo)
    
    mx.exec.update.arg.arrays(M2P_gen_executor, arg.arrays = list(M2P_img = fake.Monet_output), match.name = TRUE)
    mx.exec.forward(M2P_gen_executor, is.train = TRUE)
    restored.Photo_output <- M2P_gen_executor$ref.outputs[[1]]
    restored.Photo_img <- as.array(restored.Photo_output)
    
    # Cycle consistency loss (Photo)
    
    mx.exec.update.arg.arrays(cycle_consistency_executor, arg.arrays = list(pred = restored.Photo_output, label = my_values[['photo']]), match.name = TRUE)
    mx.exec.forward(cycle_consistency_executor, is.train = TRUE)
    mx.exec.backward(cycle_consistency_executor)
    
    batch_logger$Photo_cycle_consistency_loss <- c(batch_logger$Photo_cycle_consistency_loss, as.array(cycle_consistency_executor$ref.outputs[[1]]))
    
    # Generator-1 backward (last in the chain, updated first)
    
    M2P_grads <- cycle_consistency_executor$ref.grad.arrays[['pred']]
    mx.exec.backward(M2P_gen_executor, out_grads = M2P_grads)
    M2P_gen_update_args <- M2P_gen_updater(weight = M2P_gen_executor$ref.arg.arrays, grad = M2P_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(M2P_gen_executor, M2P_gen_update_args, skip.null = TRUE)
    
    # Generator-2 backward (gradient chained through Generator-1's input)
    
    P2M_grads <- M2P_gen_executor$ref.grad.arrays[['M2P_img']]
    mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
    P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
    
    #################################
    #                               #
    # Identity mapping loss (Part2) #
    #                               #
    #################################
    
    # Generator-2 forward (real Monet to fake Monet): P2M applied to an image
    # already in its target (Monet) domain should change it as little as
    # possible.
    
    mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['monet']]), match.name = TRUE)
    mx.exec.forward(P2M_gen_executor, is.train = TRUE)
    mirror.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
    mirror.Monet_img <- as.array(mirror.Monet_output)
    
    # Identity mapping loss (Monet)
    
    mx.exec.update.arg.arrays(identity_mapping_executor, arg.arrays = list(pred = mirror.Monet_output, label = my_values[['monet']]), match.name = TRUE)
    mx.exec.forward(identity_mapping_executor, is.train = TRUE)
    mx.exec.backward(identity_mapping_executor)
    
    batch_logger$Monet_identity_mapping_loss <- c(batch_logger$Monet_identity_mapping_loss, as.array(identity_mapping_executor$ref.outputs[[1]]))
    
    # Generator-2 backward (identity-mapping gradient, one optimizer step)
    
    P2M_grads <- identity_mapping_executor$ref.grad.arrays[['pred']]
    mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
    P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
    
    ############################
    #                          #
    # Adversarial loss (Part2) #
    #                          #
    ############################
    
    # Mirror image of adversarial Part1 with the domains swapped: the Monet
    # critic is trained on fake (label 1) and real (label 0) Monet images,
    # then the P2M generator is updated through the critic's input gradient.
    
    # Generator-2 forward (real Photo to fake Monet)
    
    mx.exec.update.arg.arrays(P2M_gen_executor, arg.arrays = list(P2M_img = my_values[['photo']]), match.name = TRUE)
    mx.exec.forward(P2M_gen_executor, is.train = TRUE)
    fake.Monet_output <- P2M_gen_executor$ref.outputs[[1]]
    
    # Discriminator-1 fake (Monet)
    
    mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = fake.Monet_output, label = mx.nd.array(rep(1, Batch_size))), match.name = TRUE)
    mx.exec.forward(Monet_dis_executor, is.train = TRUE)
    mx.exec.backward(Monet_dis_executor)
    Monet_dis_update_args <- Monet_dis_updater(weight = Monet_dis_executor$ref.arg.arrays, grad = Monet_dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_update_args, skip.null = TRUE)
    
    batch_logger$Monet_adversarial_loss.fake <- c(batch_logger$Monet_adversarial_loss.fake, as.array(Monet_dis_executor$ref.outputs[[1]]))
    
    # Discriminator-1 real (Monet)
    
    mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = my_values[['monet']], label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
    mx.exec.forward(Monet_dis_executor, is.train = TRUE)
    mx.exec.backward(Monet_dis_executor)
    Monet_dis_update_args <- Monet_dis_updater(weight = Monet_dis_executor$ref.arg.arrays, grad = Monet_dis_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(Monet_dis_executor, Monet_dis_update_args, skip.null = TRUE)
    
    batch_logger$Monet_adversarial_loss.real <- c(batch_logger$Monet_adversarial_loss.real, as.array(Monet_dis_executor$ref.outputs[[1]]))
    
    # Adversarial loss (Monet): fakes with the "real" label (0); backward only,
    # critic weights are not updated in this pass.
    
    mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = list(Monet_img = fake.Monet_output, label = mx.nd.array(rep(0, Batch_size))), match.name = TRUE)
    mx.exec.forward(Monet_dis_executor, is.train = TRUE)
    mx.exec.backward(Monet_dis_executor)
    
    batch_logger$Monet_adversarial_loss.gen <- c(batch_logger$Monet_adversarial_loss.gen, as.array(Monet_dis_executor$ref.outputs[[1]]))
    
    # Generator-2 backward (adversarial gradient via the critic's image input)
    
    P2M_grads <- Monet_dis_executor$ref.grad.arrays[['Monet_img']]
    mx.exec.backward(P2M_gen_executor, out_grads = P2M_grads)
    P2M_gen_update_args <- P2M_gen_updater(weight = P2M_gen_executor$ref.arg.arrays, grad = P2M_gen_executor$ref.grad.arrays)
    mx.exec.update.arg.arrays(P2M_gen_executor, P2M_gen_update_args, skip.null = TRUE)
    
    # Weight clipping (Discriminator-1): clamp every *weight* array into
    # [-w_limit, w_limit]; biases are left untouched.
    
    if (!is.null(w_limit)) {
      
      dis_weight_names <- grep('weight', names(Monet_dis_executor$ref.arg.arrays), value = TRUE)
      
      for (k in dis_weight_names) {
        
        current_dis_weight <- Monet_dis_executor$ref.arg.arrays[[k]] %>% as.array()
        current_dis_weight_list <- current_dis_weight %>% mx.nd.array() %>%
          mx.nd.broadcast.minimum(., mx.nd.array(w_limit)) %>%
          mx.nd.broadcast.maximum(., mx.nd.array(-w_limit)) %>%
          list()
        names(current_dis_weight_list) <- k
        mx.exec.update.arg.arrays(Monet_dis_executor, arg.arrays = current_dis_weight_list, match.name = TRUE)
        
      }
      
    }
    
    ############################
    #                          #
    # Show current performance #
    #                          #
    ############################
    
    if (current_batch %% n.print == 0) {
      
      # Show current images: for each sample, plot original / translated /
      # identity-mapped / cycle-reconstructed images for both domains.
      # num_show_img and Show_img are defined elsewhere in the document.
      
      par(mfrow = c(num_show_img * 2, 4), mar = c(0.1, 0.1, 0.1, 0.1))
      
      for (i in 1:num_show_img) {
        Show_img(img = as.array(my_values[['monet']])[,,,i])
        Show_img(img = as.array(fake.Photo_img)[,,,i])
        Show_img(img = as.array(mirror.Monet_img)[,,,i])
        Show_img(img = as.array(restored.Monet_img)[,,,i])
      }
      
      for (i in 1:num_show_img) {
        Show_img(img = as.array(my_values[['photo']])[,,,i])
        Show_img(img = as.array(fake.Monet_img)[,,,i])
        Show_img(img = as.array(mirror.Photo_img)[,,,i])
        Show_img(img = as.array(restored.Photo_img)[,,,i])
      }
      
      # Show speed (seconds per batch since the start of this epoch)
      
      speed_per_batch <- as.numeric(Sys.time() - t0, units = 'secs') / (current_batch + 1)
      
      # Show loss: epoch-to-date mean of every logged term, 4 decimals
      
      current_loss <- batch_logger %>% sapply(., mean) %>% formatC(., 4, format = 'f')
      
      message('Epoch [', j, '] Batch [', current_batch, '] loss list (', formatC(speed_per_batch, 2, format = 'f'), ' sec/batch):')
      message(paste(paste(names(current_loss), current_loss, sep = ': '), collapse = '\n'))
      
    }
    
    current_batch <- current_batch + 1
    
  }
  
  # Save models
  # Each executor's parameters are wrapped into an MXFeedForwardModel and
  # written to disk once per epoch. NOTE(review): [-1] drops the first arg
  # array — presumably the bound data input (e.g. M2P_img), which is not a
  # trainable parameter; confirm the array ordering.
  
  M2P_gen_model <- list()
  M2P_gen_model$symbol <- M2P_gen
  M2P_gen_model$arg.params <- M2P_gen_executor$ref.arg.arrays[-1]
  M2P_gen_model$aux.params <- M2P_gen_executor$ref.aux.arrays
  class(M2P_gen_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = M2P_gen_model, prefix = paste0('model/CycleGAN_', model_name, '/M2P_gen_', model_name), iteration = j)
  
  P2M_gen_model <- list()
  P2M_gen_model$symbol <- P2M_gen
  P2M_gen_model$arg.params <- P2M_gen_executor$ref.arg.arrays[-1]
  P2M_gen_model$aux.params <- P2M_gen_executor$ref.aux.arrays
  class(P2M_gen_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = P2M_gen_model, prefix = paste0('model/CycleGAN_', model_name, '/P2M_gen_', model_name), iteration = j)
  
  # Discriminators are saved with the raw score symbol (Monet_dis/Photo_dis),
  # not the loss-wrapped symbol they were trained with.
  
  Monet_dis_model <- list()
  Monet_dis_model$symbol <- Monet_dis
  Monet_dis_model$arg.params <- Monet_dis_executor$ref.arg.arrays[-1]
  Monet_dis_model$aux.params <- Monet_dis_executor$ref.aux.arrays
  class(Monet_dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = Monet_dis_model, prefix = paste0('model/CycleGAN_', model_name, '/Monet_dis_', model_name), iteration = j)
  
  Photo_dis_model <- list()
  Photo_dis_model$symbol <- Photo_dis
  Photo_dis_model$arg.params <- Photo_dis_executor$ref.arg.arrays[-1]
  Photo_dis_model$aux.params <- Photo_dis_executor$ref.aux.arrays
  class(Photo_dis_model) <- "MXFeedForwardModel"
  
  mx.model.save(model = Photo_dis_model, prefix = paste0('model/CycleGAN_', model_name, '/Photo_dis_', model_name), iteration = j)
  
}

家庭作業:利用CycleGAN實現更多有趣的任務

F16_15

– 當然,想要真的訓練的很好,你可能需要加深加寬Model architecture,以及使用完整的資料集,而這樣需要極大的運算資源,你可能需要使用GPU server。

F16_14

結語

– 對抗生成網路的變化性非常的有趣,尤其像是CycleGAN的概念,我們的研究已經開始往無監督學習的方向前進了,這是讓機器更像人類非常重要的一步!

– 儘管還有一些任務我們沒有帶各位實現,但你應該能想像到「看圖說話」、「聊天對答」等是怎樣做到的了吧?